{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ONNX Inference on Spark\n",
    "\n",
    "In this example, you will train a LightGBM model, convert it to ONNX format, and use the converted model to run inference on some test data on Spark.\n",
    "\n",
    "Python dependencies:\n",
    "\n",
    "- onnxmltools==1.7.0\n",
    "- lightgbm==3.2.1\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from pyspark.sql import SparkSession\n",
    "\n",
    "# Bootstrap Spark Session\n",
    "# getOrCreate() hands back the session the platform already started, or builds one if none exists.\n",
    "spark = SparkSession.builder.getOrCreate()\n",
    "\n",
    "from synapse.ml.core.platform import running_on_synapse\n",
    "\n",
    "# On Synapse, `display` has to be imported explicitly from notebookutils.\n",
    "# NOTE(review): on other platforms `display` is assumed to be supplied by the notebook runtime - confirm.\n",
    "if running_on_synapse():\n",
    "    from notebookutils.visualization import display"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the bankruptcy-prediction CSV used as training data.\n",
    "data_path = \"wasbs://publicwasb@mmlspark.blob.core.windows.net/company_bankruptcy_prediction_data.csv\"\n",
    "\n",
    "# Read the CSV with its header row as column names and let Spark infer the column types.\n",
    "df = spark.read.csv(data_path, header=True, inferSchema=True)\n",
    "\n",
    "display(df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use LightGBM to train a model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import VectorAssembler\n",
    "from synapse.ml.lightgbm import LightGBMClassifier\n",
    "\n",
    "# Every column after the first is treated as a feature and packed into a single vector column.\n",
    "feature_cols = df.columns[1:]\n",
    "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n",
    "\n",
    "# Keep only the label column and the assembled feature vector for training.\n",
    "train_data = featurizer.transform(df).select(\"Bankrupt?\", \"features\")\n",
    "\n",
    "# Configure the unfitted estimator; the name `model` is kept for the fitted result below.\n",
    "classifier = (\n",
    "    LightGBMClassifier(featuresCol=\"features\", labelCol=\"Bankrupt?\")\n",
    "    # Objective and class-imbalance handling\n",
    "    .setObjective(\"binary\")\n",
    "    .setIsUnbalance(True)\n",
    "    # Boosting schedule and parallelism\n",
    "    .setNumIterations(1000)\n",
    "    .setEarlyStoppingRound(300)\n",
    "    .setNumThreads(-1)\n",
    "    # Tree shape\n",
    "    .setNumLeaves(31)\n",
    "    .setMaxDepth(-1)\n",
    "    # Regularization and split constraints\n",
    "    .setLambdaL1(0.5)\n",
    "    .setMaxDeltaStep(0.5)\n",
    "    .setMinSumHessianInLeaf(20)\n",
    "    .setMinGainToSplit(0.01)\n",
    "    # Row and feature subsampling\n",
    "    .setBaggingFraction(0.7)\n",
    "    .setBaggingFreq(2)\n",
    "    .setFeatureFraction(0.7)\n",
    ")\n",
    "\n",
    "model = classifier.fit(train_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Export the trained model to a LightGBM booster, then convert it to ONNX format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from synapse.ml.core.platform import running_on_binder\n",
    "\n",
    "# On Binder only, install the pinned LightGBM version before it is imported below.\n",
    "if running_on_binder():\n",
    "    !pip install lightgbm==3.2.1\n",
    "    # NOTE(review): get_ipython is not referenced in this cell - presumably kept for tooling\n",
    "    # that rewrites the `!pip` line into a get_ipython() call; confirm before removing.\n",
    "    from IPython import get_ipython\n",
    "from typing import Union\n",
    "\n",
    "import lightgbm as lgb\n",
    "from lightgbm import Booster, LGBMClassifier\n",
    "\n",
    "\n",
    "def convertModel(\n",
    "    lgbm_model: Union[LGBMClassifier, Booster], input_size: int\n",
    ") -> bytes:\n",
    "    \"\"\"Convert a LightGBM model into a serialized ONNX model.\n",
    "\n",
    "    Args:\n",
    "        lgbm_model: The LightGBM classifier or booster to convert.\n",
    "        input_size: Number of feature columns. The ONNX graph is declared with one\n",
    "            float tensor input named `input` of shape [-1, input_size]\n",
    "            (-1 presumably leaves the batch dimension variable - confirm).\n",
    "\n",
    "    Returns:\n",
    "        The converted ONNX model as bytes, produced by SerializeToString().\n",
    "    \"\"\"\n",
    "    # Imported lazily so the conversion packages are only needed when this is called.\n",
    "    from onnxmltools.convert import convert_lightgbm\n",
    "    from onnxconverter_common.data_types import FloatTensorType\n",
    "\n",
    "    # `Union[...]` is used in the signature because `A or B` evaluates to just `A`.\n",
    "    initial_types = [(\"input\", FloatTensorType([-1, input_size]))]\n",
    "    onnx_model = convert_lightgbm(\n",
    "        lgbm_model, initial_types=initial_types, target_opset=9\n",
    "    )\n",
    "    return onnx_model.SerializeToString()\n",
    "\n",
    "\n",
    "# Rebuild a native LightGBM Booster from the model string held by the fitted Spark model,\n",
    "# then convert it to an ONNX payload sized to the training feature count.\n",
    "spark_booster = model.getLightGBMBooster()\n",
    "booster_model_str = spark_booster.modelStr().get()\n",
    "booster = lgb.Booster(model_str=booster_model_str)\n",
    "model_payload_ml = convertModel(booster, input_size=len(feature_cols))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the ONNX payload into an `ONNXModel`, and inspect the model inputs and outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from synapse.ml.onnx import ONNXModel\n",
    "\n",
    "# Wrap the serialized ONNX bytes in a Spark transformer.\n",
    "onnx_ml = ONNXModel().setModelPayload(model_payload_ml)\n",
    "\n",
    "# Print the model's declared inputs and outputs; their names are needed for the feed/fetch mappings.\n",
    "model_inputs = onnx_ml.getModelInputs()\n",
    "print(f\"Model inputs:{model_inputs}\")\n",
    "model_outputs = onnx_ml.getModelOutputs()\n",
    "print(f\"Model outputs:{model_outputs}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Map the model input to the input dataframe's column name (FeedDict), and map the output dataframe's column names to the model outputs (FetchDict)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ONNX input name -> DataFrame column that feeds it.\n",
    "feed_dict = {\"input\": \"features\"}\n",
    "# Output DataFrame column name -> ONNX output it is fetched from.\n",
    "fetch_dict = {\"probability\": \"probabilities\", \"prediction\": \"label\"}\n",
    "\n",
    "onnx_ml = (\n",
    "    onnx_ml.setDeviceType(\"CPU\")\n",
    "    .setFeedDict(feed_dict)\n",
    "    .setFetchDict(fetch_dict)\n",
    "    .setMiniBatchSize(5000)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create some test data and run it through the ONNX model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import VectorAssembler\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "# Synthetic test set: n rows of uniform random values, one column per model feature.\n",
    "n = 1_000_000\n",
    "# Derived from the training features so the vector width always matches the ONNX input size.\n",
    "m = len(feature_cols)\n",
    "# Fixed seed so a fresh Restart & Run All produces the same test data.\n",
    "rng = np.random.default_rng(42)\n",
    "test = rng.random((n, m))\n",
    "testPdf = pd.DataFrame(test)\n",
    "# pandas assigns integer column labels; they are converted to strings for use as Spark column names.\n",
    "cols = [str(c) for c in testPdf.columns]\n",
    "testDf = spark.createDataFrame(testPdf)\n",
    "# Double the row count (2 * n rows) and spread the data over 200 partitions.\n",
    "testDf = testDf.union(testDf).repartition(200)\n",
    "# Assemble the raw columns into the `features` vector the ONNX model is fed from, then drop them.\n",
    "testDf = (\n",
    "    VectorAssembler()\n",
    "    .setInputCols(cols)\n",
    "    .setOutputCol(\"features\")\n",
    "    .transform(testDf)\n",
    "    .drop(*cols)\n",
    "    .cache()\n",
    ")\n",
    "\n",
    "display(onnx_ml.transform(testDf))"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
